# coding=utf-8

from sqlalchemy import func

from common import orm
from common.channel import admin_db as channel_db
from common.mch import db as mch_db
from common.order.model import PayOrder, PAY_STATUS, ManualOrder
from common.utils import track_logging
from common.utils import tz
from common.utils.db import list_object, get, upsert, delete, generate_filter
from common.utils.decorator import sql_wrapper

# Module-level logger keyed by this module's name (not referenced by the
# functions visible in this chunk).
_LOGGER = track_logging.getLogger(__name__)


@sql_wrapper
def get_order(id):
    """Look up a single PayOrder by ``id`` using the generic ``get`` helper.

    :param id: identifier of the order to fetch.
    :return: whatever ``get`` returns for this model and id.
    """
    order = get(PayOrder, id)
    return order


@sql_wrapper
def upsert_order(info, id=None):
    """Insert or update a PayOrder through the generic ``upsert`` helper.

    :param info: field values forwarded unchanged to ``upsert``.
    :param id: order id forwarded to ``upsert``; defaults to ``None``
        (presumably meaning "create a new row" — confirm in ``upsert``).
    :return: whatever ``upsert`` returns.
    """
    saved = upsert(PayOrder, info, id)
    return saved


@sql_wrapper
def list_order(query_dct):
    """List PayOrders matching ``query_dct`` via ``list_object``.

    :param query_dct: filter/pagination dict understood by ``list_object``.
    :return: the ``list_object`` result unchanged (other code in this module
        unpacks it as ``(items, _)``).
    """
    page = list_object(query_dct, PayOrder)
    return page


@sql_wrapper
def list_manual_order(query_dct):
    """List ManualOrders matching ``query_dct`` via ``list_object``.

    :param query_dct: filter/pagination dict understood by ``list_object``.
    :return: the ``list_object`` result unchanged.
    """
    page = list_object(query_dct, ManualOrder)
    return page


@sql_wrapper
def get_manual_count_by_id(channel_id):
    """Count the ManualOrder rows whose ``channel_id`` equals ``channel_id``.

    :param channel_id: channel to count manual orders for.
    :return: the row count from ``Query.count()``.
    """
    same_channel = ManualOrder.channel_id == channel_id
    return ManualOrder.query.filter(same_channel).count()


# Row cap for export_order, passed to list_object as the '$size' value.
# NOTE: this must stay a string, not an int — presumably because list_object
# parses '$size' like a request query-string value; confirm before changing.
MAX_EXPORT_NUMBER = '5000'


@sql_wrapper
def export_order(query_dct):
    """Build export rows for the PayOrders matching ``query_dct``.

    At most ``MAX_EXPORT_NUMBER`` orders are fetched (via the ``$size`` key).

    :param query_dct: filter dict understood by ``list_object``. A copy is
        taken before ``$size`` is set, so the caller's dict is left untouched.
    :return: list of rows, one per order:
        ``['JP' + id, 'JP' + out_trade_no, user_id, third_id, created_at,
        mch_id, channel_id, channel name, total_fee, payed_at, status label]``;
        falsy user_id / third_id / payed_at are rendered as ``'-'``.
    """
    # Copy so the export size cap does not leak back into the caller's dict.
    query_dct = query_dct.copy()
    query_dct.update({'$size': MAX_EXPORT_NUMBER})

    items, _ = list_object(query_dct, PayOrder)

    # channel_id -> display name. Each distinct channel is fetched once,
    # instead of issuing one channel query per exported row.
    channel_names = {}
    resp_items = []
    for item in items:
        channel_id = item.channel_id
        if channel_id not in channel_names:
            channel_obj = channel_db.get_channel(channel_id)
            # Placeholder text reads "channel has been deleted".
            channel_names[channel_id] = (channel_obj.name if channel_obj
                                         else u'渠道已被删除')
        created_at = tz.utc_to_local_str(item.created_at)
        # Unpaid orders have no payed_at; skip the conversion in that case.
        payed_at = tz.utc_to_local_str(item.payed_at) if item.payed_at else None
        status = PAY_STATUS.get_label(item.status)
        data_row = ['JP' + str(item.id), 'JP' + str(item.out_trade_no),
                    item.user_id or '-', item.third_id or '-',
                    created_at, item.mch_id, channel_id,
                    channel_names[channel_id],
                    item.total_fee, payed_at or '-', status]
        resp_items.append(data_row)
    return resp_items


@sql_wrapper
def export_reconciliation_order(query_dct):
    """Build reconciliation-report rows for every PayOrder matching ``query_dct``.

    Pagination is disabled, so all matching orders are exported.

    :param query_dct: filter dict understood by ``list_object``.
    :return: list of rows ``[order_id, out_trade_no, third_id, created_at,
        mch_name, total_fee, payed_at, alternate_day, manual_charge]``, where
        ``alternate_day`` and ``manual_charge`` are the YES/NO labels below.
    """
    YES = u'是'  # "yes"
    NO = u'否'  # "no"
    DEFAULT_DASH_VALUE = '-'
    REASON = 'reason'
    ORDER_NO_PREFIX = 'D_'
    COMPANY_ORDER_NO_PREFIX = 'T_'

    orders, _ = list_object(query_dct, PayOrder, disable_paginate=True)
    mch_list = mch_db.get_mch_dct()
    resp_items = []

    for order in orders:
        order_id = ORDER_NO_PREFIX + str(order.id)
        out_trade_no = COMPANY_ORDER_NO_PREFIX + str(order.out_trade_no)
        third_id = order.third_id or DEFAULT_DASH_VALUE
        created_at = tz.utc_to_local_str(order.created_at)
        # A merchant missing from the lookup (e.g. deleted) must not abort the
        # whole export; show a dash instead of raising KeyError.
        mch_name = mch_list.get(order.mch_id, DEFAULT_DASH_VALUE)
        total_fee = float(order.total_fee)

        if order.payed_at:
            payed_at = tz.utc_to_local_str(order.payed_at) or DEFAULT_DASH_VALUE
            created, paid = order.created_at, order.payed_at
            # Compare the full calendar date, not only day-of-month: an order
            # created on Jan 5 and paid on Feb 5 did cross a day boundary.
            # NOTE(review): both values are fed to utc_to_local_str, so they
            # appear to be UTC; the boundary checked here is therefore the UTC
            # one, not the local one — confirm which the report needs.
            crossed = ((created.year, created.month, created.day)
                       != (paid.year, paid.month, paid.day))
            alternate_day = YES if crossed else NO
        else:
            payed_at = DEFAULT_DASH_VALUE
            alternate_day = NO

        # Orders whose ``extend`` contains a 'reason' entry are flagged as
        # manual charges.
        manual_charge = YES if order.extend and REASON in order.extend else NO

        resp_items.append([order_id, out_trade_no, third_id,
                           created_at, mch_name, total_fee,
                           payed_at, alternate_day, manual_charge])
    return resp_items


@sql_wrapper
def delete_order(id):
    """Delete the PayOrder identified by ``id`` via the generic ``delete`` helper.

    The helper's return value is discarded; this function returns ``None``.
    """
    delete(PayOrder, id)
    return None


@sql_wrapper
def get_order_overview(parsed_dct):
    """Summarise successful PayOrders grouped by (channel_type, channel_id).

    :param parsed_dct: filter dict passed to ``generate_filter``. Its
        ``channel_type`` / ``channel_id`` values (default 0) also replace
        falsy group keys in the returned dicts.
    :return: ``(resp, chn_ids)``. ``resp`` is a list of dicts with keys
        ``channel_type``, ``channel_id``, ``count`` and ``total`` (0 when the
        SUM is NULL). ``chn_ids`` holds each group's raw channel_id, without
        the fallback applied.
    """
    query = orm.session.query(PayOrder.channel_type, PayOrder.channel_id,
                              # func.count() with no argument renders COUNT(*)
                              func.count(), func.sum(PayOrder.total_fee))
    query = query.filter(generate_filter(parsed_dct, PayOrder))
    query = query.filter(PayOrder.status == PAY_STATUS.SUCC)
    query = query.group_by(PayOrder.channel_type, PayOrder.channel_id)

    resp = []
    chn_ids = []
    # Each row is (channel_type, channel_id, count, sum_of_total_fee).
    for channel_type, channel_id, count, total in query.all():
        resp.append({
            "channel_type": channel_type or parsed_dct.get('channel_type', 0),
            "channel_id": channel_id or parsed_dct.get('channel_id', 0),
            "count": count,
            "total": float(total) if total is not None else 0,
        })
        chn_ids.append(channel_id)
    return resp, chn_ids
